#!/usr/bin/env python3

import math
from typing import Type

import torch
import torch.nn as nn

from modern_hopfield_attention.layers import Mlp, ModernHopfieldAttention


class Block(nn.Module):
    """Pre-norm transformer block: Hopfield attention sub-layer then an MLP sub-layer.

    Each sub-layer is preceded by its own normalisation layer and wrapped in a
    residual connection. The attention layer is always built with
    ``causal=True`` and ``qkv_bias=True``. An auxiliary tensor ``h`` is handed
    to the attention layer and whatever it returns is passed back to the caller.
    The meaning of ``h`` is defined by ``ModernHopfieldAttention``.
    """

    def __init__(
        self,
        attn_alpha: float,
        skip_alpha: float,
        dim: int,
        num_heads: int,
        num_tokens: int | None = None,
        mlp_ratio: float = 4.0,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_layer: Type[nn.Module] = Mlp,
        *args,
        **kwargs,
    ) -> None:
        """Create the block's sub-modules.

        Args:
            attn_alpha: Forwarded to ``ModernHopfieldAttention``.
            skip_alpha: Forwarded to ``ModernHopfieldAttention``.
            dim: Feature width of the block's input and output.
            num_heads: Number of attention heads.
            num_tokens: Forwarded to ``ModernHopfieldAttention`` (may be None).
            mlp_ratio: MLP hidden width as a multiple of ``dim`` (truncated to int).
            qk_norm: Forwarded to ``ModernHopfieldAttention``.
            proj_drop: Forwarded to ``ModernHopfieldAttention``.
            attn_drop: Forwarded to ``ModernHopfieldAttention``.
            act_layer: Activation class handed to the MLP.
            norm_layer: Normalisation class used for both pre-norms and passed
                to the attention layer.
            mlp_layer: MLP class, called with ``in_features``,
                ``hidden_features`` and ``act_layer``.
            *args, **kwargs: Passed through to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)

        hidden_width = int(mlp_ratio * dim)
        attention_options = dict(
            dim=dim,
            num_heads=num_heads,
            attn_alpha=attn_alpha,
            skip_alpha=skip_alpha,
            causal=True,
            qkv_bias=True,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            num_tokens=num_tokens,
        )

        # Registration order (norm1, attn, norm2, mlp) fixes parameter order,
        # the state_dict layout, and the RNG draws consumed at init.
        self.norm1 = norm_layer(dim)
        self.attn = ModernHopfieldAttention(**attention_options)
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
        )

    def forward(
        self, x: torch.Tensor, h: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            x: Block input; presumably (batch, seq, dim) -- TODO confirm.
            h: Optional auxiliary tensor forwarded to the attention layer.

        Returns:
            ``(output, h)``: the block output and the ``h`` returned by the
            attention layer.
        """
        # Independent copy of the input for the attention skip path.
        shortcut = x.clone()
        attended, h = self.attn(self.norm1(x), h)
        x = attended + shortcut

        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out, h


class MHAGPT2(nn.Module):
    """GPT-2 style decoder-only language model built from Hopfield-attention blocks.

    Pipeline: token embedding + learned position embedding -> dropout ->
    ``depth`` x ``Block`` -> final ``nn.LayerNorm`` -> linear head over the
    vocabulary. The head weight is tied to the token-embedding weight.
    """

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        num_heads: int,
        depth: int,
        num_tokens: int,
        attn_alpha: float,
        skip_alpha: float,
        dropout: float = 0.0,
        mlp_ratio: float = 4.0,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_layer: Type[nn.Module] = Mlp,
        *args,
        **kwargs,
    ):
        """Build and initialise the model.

        Args:
            vocab_size: Number of token ids (embedding rows and head outputs).
            dim: Model width.
            num_heads: Attention heads per block.
            depth: Number of blocks; also scales the residual-projection init.
            num_tokens: Maximum sequence length (size of the position table).
            attn_alpha, skip_alpha: Forwarded unchanged to every ``Block``.
            dropout: Dropout probability applied to the embedding sum.
            mlp_ratio, qk_norm, proj_drop, attn_drop, act_layer, norm_layer,
            mlp_layer: Forwarded unchanged to every ``Block``.
            *args, **kwargs: Passed through to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        # layer
        self.transformer = nn.ModuleDict(
            dict(
                token_embedding=nn.Embedding(vocab_size, dim),
                position_embedding=nn.Embedding(
                    num_tokens,
                    dim,
                ),
                drop=nn.Dropout(dropout),
                blocks=nn.ModuleDict(
                    {
                        f"{i}": Block(
                            attn_alpha=attn_alpha,
                            skip_alpha=skip_alpha,
                            dim=dim,
                            num_heads=num_heads,
                            num_tokens=num_tokens,
                            mlp_ratio=mlp_ratio,
                            qk_norm=qk_norm,
                            proj_drop=proj_drop,
                            attn_drop=attn_drop,
                            act_layer=act_layer,
                            norm_layer=norm_layer,
                            mlp_layer=mlp_layer,
                        )
                        for i in range(depth)
                    }
                ),
                layer_norm=nn.LayerNorm(dim),
            )
        )
        self.head = nn.Linear(dim, vocab_size, bias=False)

        ## weight tying: the embedding and the output head share one matrix
        self.transformer.token_embedding.weight = self.head.weight

        ## init all weights
        self.apply(self._init_weights)

        ## apply special scaled init to the residual projections, per GPT-2 paper.
        ## NOTE(review): this matches every parameter whose name ends with
        ## "proj.weight"; which layers that covers depends on the attribute names
        ## inside ModernHopfieldAttention / mlp_layer -- confirm.
        for name, param in self.named_parameters():
            if name.endswith("proj.weight"):
                torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * depth))

        # buffer: position ids 0..num_tokens-1, sliced per batch in forward()
        self.register_buffer("pos", torch.arange(0, num_tokens, dtype=torch.long))

        # Inputs captured by register_hooks(), and the handles of those hooks.
        self.hook_input = list()
        self._hook_handles = list()

    def _init_weights(self, module: nn.Module) -> None:
        """GPT-2 style initialisation of one sub-module (used via ``self.apply``).

        Linear/Embedding weights ~ N(0, 0.02) and Linear biases = 0. LayerNorm
        starts as the identity affine map (gain 1, bias 0).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # Gains drawn from N(0, 0.02) would scale every normalised
            # activation down ~50x at init; the standard choice is 1.
            # weight/bias are None when elementwise_affine=False.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape (batch, n) with ``n <= num_tokens``.

        Returns:
            Logits of shape (batch, n, vocab_size).

        Raises:
            ValueError: If ``n`` exceeds the position table size.
        """
        _, n = x.size()
        max_tokens = self.pos.size(0)
        if n > max_tokens:
            raise ValueError(
                f"sequence length {n} exceeds the maximum of {max_tokens} tokens"
            )

        tokens = self.transformer.token_embedding(x)
        # (n, dim) positions broadcast over the batch dimension.
        positions = self.transformer.position_embedding(self.pos[:n])

        # Position information is added exactly once.
        x = self.transformer.drop(tokens + positions)

        # h is threaded from each block's attention layer into the next block.
        h = None
        for block in self.transformer.blocks.values():
            x, h = block(x, h)

        x = self.transformer.layer_norm(x)
        return self.head(x)

    def register_hooks(self) -> None:
        """Record (detached, on CPU) the input of every block's attention layer.

        Calling this again replaces the previously registered hooks instead of
        stacking duplicates, and resets ``hook_input``.
        """
        self.remove_hooks()
        self.hook_input = list()

        def hook_fn(module, args, output) -> None:
            if isinstance(module, ModernHopfieldAttention):
                self.hook_input.append(args[0].detach().cpu())

        self._hook_handles = [
            block.attn.register_forward_hook(hook_fn)
            for block in self.transformer.blocks.values()
        ]

    def remove_hooks(self) -> None:
        """Detach any hooks installed by ``register_hooks``."""
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = list()

    def clear_hooks(self) -> None:
        """Discard captured inputs; hooks that are registered stay active."""
        self.hook_input = list()
